Hierarchical Bayesian Neural Networks with variable group sizes
probabilistic machine learning
bayesian neural networks
hierarchical modeling
jax
blackjax
Author
Julius Mehringer
Published
June 11, 2025
Warning
This entry is still very much a work in progress. Please check back later for updates.
Bayesian Modeling is the primary choice if you want to obtain the uncertainty associated with the predictions of a model. Of course, there are also voices arguing that Bayesian Deep Learning is a promising avenue.
There are some very useful blog entries and notebooks out there (e.g. by Thomas Wiecki using Theano/PyMC3, and this repo using a more recent version of JAX).
However, those examples only work under the critical assumption that the groups are all of the same size. In reality, this is rarely the case, of course.
Here, I will show you how you can implement a Hierarchical Bayesian Neural Network irrespective of the group sizes you observe in your dataset.
Code
# Standard library
from typing import Tuple
from datetime import date
from functools import partial
from warnings import filterwarnings

# Third-party
import jax
import jax.random as jr
import jax.numpy as jnp
import equinox as eqx
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.preprocessing import scale
from sklearn.model_selection import train_test_split
import blackjax
import tensorflow_probability.substrates.jax.distributions as tfd

# Silence noisy library warnings for the notebook rendering
filterwarnings("ignore")
WARNING:2025-06-11 22:20:45,725:jax._src.xla_bridge:791: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
# Draw the classic two-moons dataset and scatter-plot each class.
# NOTE(review): `noise` and `cmap` are defined in an earlier (hidden) cell.
X, Y = make_moons(noise=noise, n_samples=1000)
for cls in range(2):
    in_class = Y == cls
    plt.scatter(
        X[in_class, 0],
        X[in_class, 1],
        color=cmap(float(cls)),
        label=f"Class {cls}",
        alpha=0.8,
    )
plt.legend()
plt.show()
The standard two-moons dataset plot
Code
def rotate(X, deg):
    """Rotate the 2-D point cloud `X` (shape (n, 2)) by `deg` degrees.

    Uses a plain ndarray rotation matrix; `np.matrix` (used originally) is
    deprecated and discouraged by NumPy.
    """
    theta = np.radians(deg)
    c, s = np.cos(theta), np.sin(theta)
    R = np.array([[c, -s], [s, c]])
    # X @ R is numerically identical to the original X.dot(np.matrix(R)),
    # but already returns an ndarray, so no np.asarray round-trip is needed.
    return X @ R


np.random.seed(31)

# Full data plus train/test splits, one entry per group.
# NOTE(review): `n_groups`, `noise` and `n_samples` come from an earlier (hidden) cell.
Xs, Ys, gs = [], [], []
Xs_train, Ys_train, gs_train, Xs_test, Ys_test, gs_test = [], [], [], [], [], []
for i in range(n_groups):
    # Generate data with 2 classes that are not linearly separable
    X, Y = make_moons(noise=noise, n_samples=n_samples[i])
    X = scale(X)
    # Rotate the points randomly for each category
    rotate_by = np.random.randn() * 90.0
    X = rotate(X, rotate_by)
    Xs.append(X)
    Ys.append(Y)
    gs.append(X.shape[0])
    X_train, X_test, Y_train, Y_test = train_test_split(
        X, Y, test_size=0.2, random_state=31
    )
    Xs_train.append(X_train)
    Ys_train.append(Y_train)
    gs_train.append(X_train.shape[0])
    Xs_test.append(X_test)
    Ys_test.append(Y_test)
    gs_test.append(X_test.shape[0])
@eqx.filter_vmap
def make_ensemble(key):
    """Build one MLP per key; filter_vmap stacks them into an ensemble."""
    # NOTE(review): `data_dim`, `hidden_layer_width` and `n_hidden_layers`
    # are defined in an earlier (hidden) cell.
    return eqx.nn.MLP(
        in_size=data_dim,
        out_size=1,
        width_size=hidden_layer_width,
        depth=n_hidden_layers,
        key=key,
    )


@eqx.filter_vmap(in_axes=(eqx.if_array(0), None))
def evaluate_ensemble(ensemble, x):
    """Evaluate every ensemble member on the same single input `x`."""
    return ensemble(x).mean()


def evaluate_per_ensemble(model, x):
    """Evaluate one model over a batch of inputs."""
    return jax.vmap(model)(x)


def apply_ensemble(ensemble, D):
    """Evaluate the i-th ensemble member on the i-th slice of `D`."""
    # ensemble_fn = partial(evaluate_ensemble, ensemble)
    return eqx.filter_vmap(evaluate_per_ensemble)(ensemble, D)


key = jr.PRNGKey(0)
hnn = make_ensemble(jr.split(key, n_groups))
Code
class NonCentredLinear(eqx.Module):
    """Linear layer in non-centred parameterisation: w = mu + std * eps."""

    # mu is the global network
    mu: jax.Array
    # eps are the n_group local networks
    eps: jax.Array
    std: jax.Array

    def __init__(self, in_size, out_size, n_groups, *, key):
        # NOTE(review): mu and eps are drawn from the same key, so their
        # first rows are correlated — confirm this is intentional.
        self.mu = jr.normal(key, (in_size, out_size))
        self.eps = jr.normal(key, (n_groups, in_size, out_size))
        self.std = jnp.ones((1,))

    def __call__(self, x):
        w = self.mu + self.std * self.eps
        return x @ w


class HNN(eqx.Module):
    """Hierarchical NN: per-group hidden layers, shared output layer."""

    layers: Tuple[NonCentredLinear]
    out: eqx.nn.Linear

    def __init__(self, layer_width, n_layers, n_groups, *, key):
        # NOTE(review): `data_dim` comes from an earlier (hidden) cell.
        dims = [data_dim] + [layer_width] * n_layers
        hidden = []
        for idx, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
            hidden.append(
                NonCentredLinear(fan_in, fan_out, n_groups, key=jr.fold_in(key, idx))
            )
        self.layers = tuple(hidden)
        # NOTE(review): reuses the top-level `key` (same as hidden layers'
        # base key) — confirm this key reuse is intentional.
        self.out = eqx.nn.Linear(layer_width, 1, key=key)

    def __call__(self, x):
        for layer in self.layers:
            x = jax.nn.tanh(layer(x))
        # Vmap over groups and samples
        return jax.vmap(jax.vmap(self.out))(x)
def get_mean_predictions(predictions, threshold=0.5):
    """Average posterior predictions over axis 0 (NaN-aware) and
    threshold the result into boolean class labels."""
    avg = jnp.nanmean(predictions, axis=0)
    return avg > threshold
def logprior_fn(params):
    """Standard-normal log prior summed over every leaf of a pytree."""
    normal = tfd.Normal(0.0, 1.0)
    leaves, _ = jax.tree_util.tree_flatten(params)
    flat = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
    return jnp.sum(normal.log_prob(flat))


def logprior_fn_of_hnn(params, model):
    """p(w) where w is NN(X; w)"""
    normal = tfd.Normal(0.0, 1.0)
    half_normal = tfd.HalfNormal(1.0)
    lp = 0.0
    # Hierarchical prior on every non-centred layer: normal on the global
    # weights and group offsets, half-normal on the scale.
    for layer in params.layers:
        lp = lp + normal.log_prob(layer.mu).sum()
        lp = lp + normal.log_prob(layer.eps).sum()
        lp = lp + half_normal.log_prob(layer.std).sum()
    # Shared output layer gets a plain standard-normal prior.
    lp = lp + logprior_fn(params.out)
    return lp


def loglikelihood_fn(params, X, Y, mask, fill_value, model):
    """p(Y|Y_=NN(X; w))"""
    logits = jnp.ravel(apply_fn(params, X))
    # apply the mask: where the mask has the fill value, the logits should also be zero
    padded = jnp.ravel(mask[:, :, 0]) == fill_value
    logits = jnp.where(padded, 0, logits)
    return jnp.sum(tfd.Bernoulli(logits).log_prob(jnp.ravel(Y)))


def logdensity_fn_of_hnn(params, X, Y, mask, fill_value, model):
    """Unnormalised log posterior: log prior + log likelihood."""
    prior = logprior_fn_of_hnn(params, model)
    likelihood = loglikelihood_fn(params, X, Y, mask, fill_value, model)
    return prior + likelihood